Skip to content

Add vmap for BroadcastAxes#3344

Merged
angeloskath merged 1 commit intomainfrom
baxes-vmap
Apr 1, 2026
Merged

Add vmap for BroadcastAxes#3344
angeloskath merged 1 commit intomainfrom
baxes-vmap

Conversation

@angeloskath
Copy link
Copy Markdown
Member

A simplified and correct version of #3319. The main point is that it moves the batch axis first and takes advantage of the BroadcastAxes primitive's negative axes handling (which was made for this tbh) to avoid all axis math. Unfortunately we may need to expand to move the vectorized axis first if the dims are smaller.

In addition to the above it makes sure that the BroadcastAxes primitive is used even when it is a noop to have a consistent graph that works all the time. Basically dealing with the broadcast version of #3202.

Copy link
Copy Markdown
Collaborator

@zcbenz zcbenz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice implementation 👍

@angeloskath angeloskath merged commit 1944cf6 into main Apr 1, 2026
16 checks passed
@angeloskath angeloskath deleted the baxes-vmap branch April 1, 2026 00:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants